
Conversation


@vthumbe1503 vthumbe1503 commented Oct 7, 2025

Description

Motivation:

  • FSDP2 training currently doesn't work with models initialized with FP8 weights. If high-precision weights are used with TE layers instead, the model consumes more memory than it would with BF16 under TE autocast, making it difficult to adopt TE for FP8-based FSDP2 training (issue). Hence it is useful to get FSDP2 working with FP8-initialized weights (issue).
  • Along with fixing the memory usage for models initialized with FP8 weight tensors, we also want FSDP2 training to be correct, i.e. the FP8 weight tensors must be updated properly after every training step. Currently the Float8Tensors for weights don't get updated. This is not specific to FSDP; it also affects DDP with FP8-initialized weights (issue).
  • We also want the FSDP weight allgather to use FP8 instead of high precision for efficient training. Currently in TE, the allgather happens in high precision even for FP8-initialized weights (issue).

What does this PR do?

  • Enables FSDP2-based model training end-to-end for any PyTorch model with TE layers and FP8-initialized weights (a minimal usage sketch follows this list)
    • Solves the memory-footprint issue with FP8-initialized weights. Initialization with FP8 (per-tensor scaling) on Blackwell takes half the memory footprint of BF16, as expected. MXFP8 and BF16 consume the same amount of memory because MXFP8 needs both rowwise and columnwise usages.
    • Fixes the FP8 weight updates when the model is initialized with FP8 weights, ensuring correctness of training results
    • Enables 8-bit weight allgather for both FP8 and MXFP8 tensors.
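
For illustration, here is a minimal, hedged sketch of the intended end-to-end flow. It assumes the existing TE PyTorch APIs (fp8_model_init, fp8_autocast, the recipe classes) and PyTorch's FSDP2 fully_shard entry point; the distributed setup and exact arguments are simplified and may differ from the PR's tests in tests/pytorch/distributed/run_fsdp2_model.py.

```python
# Hedged usage sketch (not the PR's test code): FSDP2 training with
# FP8-initialized TE weights. Distributed setup (torchrun/init_process_group)
# is assumed to have happened already.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling
from torch.distributed.fsdp import fully_shard  # FSDP2 entry point

recipe = Float8CurrentScaling()

# Weights are created directly as Float8Tensors: roughly half the weight
# memory of BF16 for per-tensor scaling (MXFP8 keeps rowwise + columnwise copies).
with te.fp8_model_init(enabled=True, recipe=recipe):
    model = torch.nn.Sequential(
        *[te.TransformerLayer(1024, 4096, 16) for _ in range(2)]
    ).cuda()

# Shard submodules and the root; the quantized weights are then allgathered
# in 8 bits through the new fsdp_pre_all_gather/fsdp_post_all_gather hooks.
for layer in model:
    fully_shard(layer)
fully_shard(model)

optim = torch.optim.AdamW(model.parameters(), lr=1e-4)
for _ in range(3):
    x = torch.randn(32, 8, 1024, device="cuda", dtype=torch.bfloat16)
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        out = model(x)
    out.float().sum().backward()
    optim.step()
    optim.zero_grad()
```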

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • FSDP Allgather Hooks for FP8/MXFP8: Adds fsdp_pre_all_gather and fsdp_post_all_gather methods for FP8/MXFP8 tensors, since FSDP2 allgather only supports native torch tensors with uint8/fp16/bf16/fp32 data types. For us, fsdp_pre_all_gather returns the sharded uint8 tensors for FP8/MXFP8 that need to be allgathered, plus the metadata needed to reconstruct the FP8/MXFP8 tensor afterwards. fsdp_post_all_gather reconstructs the Float8/MXFP8 tensor from the allgathered uint8 data (a simplified sketch of this hook contract appears after this list).

    • Handling quantization usages in allgather: The assumption here is that fsdp_pre_all_gather and fsdp_post_all_gather are only called for weight tensors, which is fair since FSDP is only used to shard weights. This means rowwise usage is needed for the forward pass and columnwise usage for the backward pass.
    • Identifying forward/backward pass during allgather: This is needed since only one of the rowwise/columnwise usages needs to be allgathered, depending on whether it is the forward or backward pass of the training step. fsdp_pre_all_gather receives module as an argument, which is the nn.Module that has the quantized tensor registered as a parameter. This module is not necessarily an FSDP module, since we might wrap the model at a much higher level of the hierarchy (e.g. the TransformerLayer rather than its Linear submodule). Hence we have a method that computes the lowest common ancestor FSDP module and uses its FSDP state to determine whether we are in the forward or backward pass. NOTE: The return value is cached with lru_cache since we don't want to recompute it on every iteration/allgather during training. The cached value is a reference that FSDP mutates internally over the course of training.
    • Reshard After Forward: FSDP2 has a configuration that controls whether parameters are resharded after the forward pass (meaning weights are allgathered again for the backward pass). By default it is set to False for the root module and True for submodules. The setting is obtained from the FSDP state of the module the parameter belongs to, and is used to decide whether we need to send both rowwise and columnwise data in one go or just one of them, based on forward/backward pass. This matters more for MXFP8, where we would rather send both usages than send just one and then dequantize and requantize to recover the missing usage (introducing quantization error).
    • Current Scaling Quantization: For current-scaling quantization we need a single amax/scale inverse across all shards, which holds when the model is initialized. However, each quantized weight shard is updated independently by the optimizer during training, so we need to set the amax reduction group on the quantizer if it is not already set. This is done in the forward-pass allgather itself (using the FSDP mesh information), so that when a weight shard is updated, the quantizer synchronizes across shards to compute a single amax, ensuring every weight shard uses the same scale inverse.
  • FP8/MXFP8 Torch Dispatch Functions for FSDP2: handle ops on both rowwise/columnwise data (MXFP8) and data/transpose (FP8). NOTE: only MXFP8 tensors without padding requirements are handled; if padding is needed we go down the dequantize-compute-quantize route.

    • Split Function: If the model is initialized on a CUDA device at the start, torch chunk/split is called on our custom quantized tensors to split them along dimension 0. After the split, FSDP2 keeps the shard needed by that process and discards everything else to free memory before training. NOTE: With meta-device (deferred) initialization this op isn't called; quantized tensors are instantiated directly for the weight shard belonging to the process rather than initializing everything and discarding the unneeded shards.
    • new_zeros: Implementing this op ensures a new tensor is created with the shape the shard is supposed to have. The original Float8Tensor implementation didn't create a deep copy of the scale inverses; that is fixed now.
    • copy_: The split/sharded tensor is then copied into the zero tensor created above.
    • as_strided: FSDP2 allows one shard to have fewer elements than the others when dimension 0 is not divisible by the number of shards; it pads the smaller shard and calls the as_strided API after allgather to remove the padding. We currently don't handle the non-divisible case (it would be complicated for MXFP8 and is beyond the scope of this PR), so as_strided is essentially a no-op for us.
    • view: In FSDP2, sharded parameters are flattened with view, and that flattened view is used for allgather when compiled autograd is enabled. However, for MXFP8 we throw an error when flattening the tensor, since the last dimension of an MXFP8 tensor should never change. In that case we now fall back to dequantization followed by a high-precision view so that FSDP2 doesn't fail, and raise a warning when it happens. This is not a concern for us at the moment since we don't use compiled autograd, so this view path is essentially unused.
  • Quantized Tensor Class Issues:

    • Missing Dequantize/Compute/Quantize Pathway: When the optimizer updates FP8/MXFP8 weights, it issues its ops (lerp for the weight update) on a list of Float8 weights instead of on each Float8 weight individually. Our normal dequantize/compute/quantize route didn't handle a list of Float8 tensors, so weights were not getting updated in place. This PR fixes that.
    • make_like API relying on data attribute: The make_like API in the QuantizedTensor base class should not set the data attribute, since that is specific to Float8Tensor. That logic is moved into the Float8Tensor class instead.
  • Validating rowwise/columnwise Usages for quantizers/tensors in TE Layers

    • Weight Tensor Usage Validation: Previously we validated the presence of all desired rowwise/columnwise usages for weight tensors in the forward pass of our layers. With FSDP2, however, different usages are allgathered for the forward and backward passes. Validation of the appropriate quantization usage is therefore moved into the forward and backward functions of the layers, i.e. rowwise usage is required in forward and columnwise usage in backward.
    • Quantizer Usage Update: We also used to update the weight quantizer even when the weights were already in FP8. In that case there is no need to update the quantizer, since quantization has already happened and that quantizer is never going to be used, so this update is removed from the code.
  • Resetting Parameters for Deferred Initialization(meta device)

    • Updating DTensors instead of regular tensors: With deferred initialization under FSDP2, parameters are DTensors that hold only the unmaterialized shard needed by the process, so the local tensor of the DTensor needs to be updated with quantized weights initialized via param_init_fn.
    • Current scaling quantization: In this case the amax reduction group needs to be passed to the quantizer so that all initialized weight shards share a single scale inverse.
  • Test and Miscellaneous issues

    • More complete test cases for FSDP2: Originally the test only covered a Linear layer. Now we can test models built from different TE layers, with combinations of with/without FP8 model init and different quantization recipes (FP8/MXFP8). NOTE: NVFP4 is pending.
    • View and reshape not handling columnwise data elegantly: When the columnwise data is present and valid, view and reshape ops are now also performed on the transpose (FP8) / columnwise data (MXFP8) instead of invalidating them.
    • Float8 make_empty API: When a transpose is requested, make_empty originally created it with shape (shape[-1], math.prod(shape[:-1])). It is now consistent with the transpose shapes we create in C++, i.e. (shape[-1], shape[0], shape[1], ..., shape[-2]). This is needed because we handle transpose ops in the torch dispatch required for FSDP2 and need to be consistent everywhere.
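
To make the allgather-hook contract from the first bullet above concrete, here is a heavily simplified, self-contained sketch of its shape. It is illustrative only: ToyFP8Weight is a made-up stand-in and the fake allgather is just a concatenation. The real Float8Tensor/MXFP8Tensor hooks in this PR additionally handle the quantizer copy, forward/backward detection, reshard_after_forward, per-usage scale-inverse metadata, and MXFP8's rowwise/columnwise pair.

```python
# Illustrative sketch only (not the PR's implementation). FSDP2 can only
# allgather plain torch tensors, so the pre-hook hands back uint8 shards plus
# metadata, and the post-hook rebuilds the quantized tensor from the gathered
# uint8 data.
import torch


class ToyFP8Weight:
    """Hypothetical stand-in for Float8Tensor: uint8 payload + per-tensor scale."""

    def __init__(self, data: torch.Tensor, scale_inv: torch.Tensor):
        self._data = data            # uint8 bits of this shard
        self._scale_inv = scale_inv  # scale inverse shared by all shards

    def fsdp_pre_all_gather(self, *fsdp_args):
        # FSDP2 also passes the mesh, original size, owning module and
        # mixed-precision policy; this toy version ignores them.
        shards = (self._data,)         # what actually gets allgathered (uint8)
        metadata = (self._scale_inv,)  # needed to reconstruct afterwards
        return shards, metadata

    @classmethod
    def fsdp_post_all_gather(cls, all_gather_outputs, metadata, param_dtype, out=None):
        (gathered_data,) = all_gather_outputs
        (scale_inv,) = metadata
        rebuilt = cls(gathered_data, scale_inv)
        # The PR's hooks return the rebuilt tensor plus the gathered buffers.
        return rebuilt, all_gather_outputs


# Fake a two-rank allgather to show the round trip.
shard0 = ToyFP8Weight(torch.randint(0, 255, (4, 8), dtype=torch.uint8), torch.tensor(0.5))
shard1 = ToyFP8Weight(torch.randint(0, 255, (4, 8), dtype=torch.uint8), torch.tensor(0.5))
(inp0, meta), (inp1, _) = shard0.fsdp_pre_all_gather(), shard1.fsdp_pre_all_gather()
gathered = torch.cat([inp0[0], inp1[0]], dim=0)  # stands in for the real allgather
full, _ = ToyFP8Weight.fsdp_post_all_gather((gathered,), meta, torch.bfloat16)
assert full._data.shape == (8, 8)
```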

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Summary by CodeRabbit

Release Notes

  • New Features
    • Added FP8 mixed-precision training support with FSDP2/HSDP distributed sharding.
    • Introduced multiple FP8 quantization scaling recipes: delayed scaling, current scaling, and MX_FP8 block scaling.
    • Expanded distributed training configuration options: batch size, sequence length, data type, layer configuration, number of layers, device placement, and sharding specification.
    • Improved distributed tensor parameter support and synchronization for FSDP integration.

@vthumbe1503 vthumbe1503 changed the title FSDP2 Weight Update Fix [Pytorch] FSDP2 Weight Update Fix Oct 8, 2025
@vthumbe1503 vthumbe1503 changed the title [Pytorch] FSDP2 Weight Update Fix [PyTorch] FSDP2 Weight Update Fix Oct 8, 2025
@vthumbe1503 vthumbe1503 changed the title [PyTorch] FSDP2 Weight Update Fix [PyTorch] TE FSDP2 Support for FP8/MXFP8 Oct 17, 2025

@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR adds FSDP2 (Fully Sharded Data Parallel 2) support to Transformer Engine, enabling distributed training with FP8 quantization. The implementation includes:

  • Core Integration: Added fsdp_pre_all_gather() and fsdp_post_all_gather() hooks to Float8Tensor and MXFP8Tensor for weight sharding/gathering during forward/backward passes
  • Tensor Operations: Implemented FSDP2-compatible __torch_dispatch__ handlers for split, slice, as_strided, copy_, and new_zeros operations with proper handling of quantized data and transpose caches
  • DTensor Support: Enhanced TransformerEngineBaseModule.reset_parameters() to handle DTensor parameters, including proper quantizer configuration with amax reduction across device mesh
  • Test Coverage: Expanded test suite with multiple FP8 scaling recipes (delayed, current, MX_FP8) and sharding configurations

Key Implementation Details:

  • Forward pass uses rowwise data representation; backward pass uses columnwise (transpose) for optimal Tensor Core performance
  • Transpose cache is maintained across tensor operations to avoid recomputation
  • Amax reduction is configured across FSDP mesh for consistent scaling across shards

Issues from Previous Comments:
Several critical issues from earlier reviews remain unresolved in mxfp8_tensor.py, particularly around None handling in the slice.Tensor dispatch handler (line 479) where out_data[0].shape assumes rowwise_data exists.

Confidence Score: 3/5

  • This PR requires fixes before merging - critical None handling issues remain from previous reviews
  • Score reflects unresolved critical issues from previous review comments. The slice.Tensor handler at mxfp8_tensor.py:479 accesses out_data[0].shape but out_data[0] can be None when rowwise_data is None, causing AttributeError. Similar patterns exist in other handlers. The Float8Tensor implementation appears more robust with better None handling. Core FSDP2 integration logic is sound but needs defensive programming for edge cases.
  • Pay close attention to transformer_engine/pytorch/tensor/mxfp8_tensor.py - the slice.Tensor, split.Tensor, and as_strided handlers need None-safety fixes for cases where rowwise_data or columnwise_data is None

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/tensor/mxfp8_tensor.py 3/5 Added FSDP2 dispatch handlers (split, as_strided, copy_, slice, new_zeros) and pre/post all-gather hooks. Previous comments flagged None handling and AttributeError risks that remain unaddressed.
transformer_engine/pytorch/tensor/float8_tensor.py 4/5 Added FSDP2 support with transpose cache handling in split/new_zeros/as_strided ops, plus pre/post all-gather hooks for distributed training. Good transpose cache maintenance logic.
transformer_engine/pytorch/module/base.py 4/5 Enhanced reset_parameters to handle DTensor (FSDP2), preserving local tensor quantization while maintaining distributed mesh metadata. Properly handles amax reduction groups.
transformer_engine/pytorch/distributed.py 3/5 Added _get_module_fsdp_state helper with @lru_cache. Previous comment noted potential stale cache issues during training state changes.
tests/pytorch/distributed/run_fsdp2_model.py 4/5 Expanded FSDP2 test runner with multiple scaling recipes (delayed, current, MX_FP8), sharding configurations, and distributed test cases.

Sequence Diagram

sequenceDiagram
    participant User as Training Script
    participant FSDP as FSDP2 Manager
    participant Module as TE Module
    participant FP8 as Float8Tensor/MXFP8Tensor
    participant Quantizer as Quantizer
    
    Note over User,Quantizer: Initialization Phase
    User->>Module: fully_shard(te_module)
    Module->>Module: register_parameter()
    Note over Module: Skip param_init_meta if already exists
    
    User->>Module: reset_parameters()
    Module->>Module: Check if param is DTensor
    alt Is DTensor
        Module->>Module: Extract _local_tensor
        Module->>Quantizer: Configure amax_reduction_group
        Module->>FP8: Quantize local tensor
        Module->>Module: Wrap back to DTensor
    else Regular Tensor
        Module->>FP8: Quantize tensor
    end
    
    Note over User,Quantizer: Forward Pass
    FSDP->>FP8: fsdp_pre_all_gather(mesh, orig_size, module, mp_policy)
    FP8->>Quantizer: Configure rowwise/columnwise usage
    Note over FP8: Set usage based on training state
    FP8->>FSDP: Return (sharded_tensors, metadata)
    
    FSDP->>FSDP: AllGather sharded_tensors
    
    FSDP->>FP8: fsdp_post_all_gather(all_gather_outputs, metadata, param_dtype)
    FP8->>FP8: Reconstruct from gathered data
    FP8->>FSDP: Return reconstructed Float8Tensor
    
    FSDP->>Module: forward() with all-gathered weights
    
    Note over User,Quantizer: Backward Pass
    FSDP->>FP8: fsdp_pre_all_gather() [backward]
    Note over FP8: Configure columnwise usage for backward
    FP8->>FSDP: Return (sharded_tensors, metadata)
    
    FSDP->>FSDP: AllGather for backward
    
    FSDP->>FP8: fsdp_post_all_gather()
    FP8->>FSDP: Return reconstructed tensor
    
    Note over User,Quantizer: Gradient Sync
    FSDP->>FSDP: ReduceScatter gradients
    FSDP->>Module: Update sharded weights

11 files reviewed, no comments


Signed-off-by: Varun Thumbe <[email protected]>

@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR adds comprehensive FSDP2 (Fully Sharded Data Parallel v2) support to Transformer Engine, enabling distributed training with FP8 quantization across multiple scaling recipes.

Key Changes:

  • Implemented fsdp_pre_all_gather() and fsdp_post_all_gather() hooks in Float8Tensor and MXFP8Tensor for FSDP2 integration
  • Added torch dispatch handlers for FSDP2 tensor operations (split, as_strided, slice, copy_, new_zeros)
  • Enhanced transpose cache management during tensor reshaping operations for improved performance
  • Added training state-aware quantizer usage selection (rowwise for forward, columnwise for backward)
  • Modified TE modules to detect pre-quantized weights and skip redundant quantizer configuration
  • Expanded test coverage with multiple recipes (delayed scaling, current scaling, MX block scaling) and layer types

Architecture:
FSDP2 shards quantized tensors across ranks. During forward/backward, fsdp_pre_all_gather() extracts sharded data and metadata, FSDP2 performs all-gather, then fsdp_post_all_gather() reconstructs the full quantized tensor with proper transpose caches based on training state.

Confidence Score: 4/5

  • Safe to merge with minor concerns about LRU cache behavior in distributed settings
  • The implementation is well-structured with proper FSDP2 integration patterns. The main concern is the @lru_cache decorator on _get_module_fsdp_state() which could potentially cache stale FSDP state if module state changes during training. The core tensor operations, quantizer handling, and test coverage are solid. No critical bugs identified, though the cache issue warrants monitoring in production.
  • transformer_engine/pytorch/distributed.py - monitor _get_module_fsdp_state() LRU cache behavior during resharding operations

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/tensor/float8_tensor.py 4/5 Adds FSDP2 support with pre/post all-gather hooks, improves torch dispatch for view/split/new_zeros/as_strided ops with transpose cache handling
transformer_engine/pytorch/tensor/mxfp8_tensor.py 4/5 Implements FSDP2 support and torch dispatch handlers for split/as_strided/copy_/slice operations, adds pre/post all-gather hooks for MX format
transformer_engine/pytorch/distributed.py 3/5 Adds _get_module_fsdp_state helper with LRU cache for retrieving FSDP state from modules - cache may cause stale state issues
transformer_engine/pytorch/module/base.py 4/5 Improves DTensor parameter handling in weight initialization, adds FSDP group support to workspace management
transformer_engine/pytorch/module/linear.py 5/5 Skips quantizer configuration when weight is already quantized (FSDP2 pre-quantized weights), removes redundant columnwise usage update

Sequence Diagram

sequenceDiagram
    participant App as Application
    participant FSDP as FSDP2
    participant TEModule as TE Module
    participant QTensor as Quantized Tensor<br/>(Float8/MXFP8)
    participant Quantizer as Quantizer

    Note over App,Quantizer: Initialization Phase
    App->>TEModule: Create with FP8 init
    TEModule->>Quantizer: Create quantizer (Delayed/Current/MX)
    TEModule->>QTensor: Quantize weights
    App->>FSDP: fully_shard(module)
    FSDP->>QTensor: Shard weights

    Note over App,Quantizer: Forward Pass (Training)
    App->>FSDP: forward()
    FSDP->>QTensor: fsdp_pre_all_gather()
    QTensor->>QTensor: Set usage based on training state
    QTensor->>FSDP: Return (sharded_tensors, metadata)
    FSDP->>FSDP: All-gather sharded tensors
    FSDP->>QTensor: fsdp_post_all_gather(outputs, metadata)
    QTensor->>QTensor: Reconstruct full tensor
    QTensor->>QTensor: update_usage() for transpose cache
    QTensor->>FSDP: Return allgathered QTensor
    FSDP->>TEModule: Forward with full weights
    TEModule->>TEModule: Check if weight is QuantizedTensor
    alt Weight already quantized
        TEModule->>TEModule: Skip quantizer setup
        TEModule->>QTensor: Use existing quantizer
    else Weight not quantized
        TEModule->>Quantizer: Set usage flags
        TEModule->>Quantizer: Quantize weight
    end
    TEModule->>App: Return output

    Note over App,Quantizer: Backward Pass
    App->>FSDP: backward()
    FSDP->>QTensor: fsdp_pre_all_gather() (PRE_BACKWARD)
    QTensor->>QTensor: Set columnwise usage for dgrad
    QTensor->>FSDP: Return (sharded_tensors, metadata)
    FSDP->>FSDP: All-gather for backward
    FSDP->>QTensor: fsdp_post_all_gather()
    QTensor->>QTensor: Reconstruct with transpose data
    FSDP->>TEModule: Backward computation
    TEModule->>FSDP: Return gradients
    FSDP->>FSDP: Reduce-scatter gradients
    FSDP->>QTensor: Update sharded weights

11 files reviewed, no comments


@vthumbe1503
Collaborator Author

/te-ci pytorch


@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

Adds FSDP2 (Fully Sharded Data Parallel 2) support for FP8 and MXFP8 quantized tensors in PyTorch Transformer Engine, enabling distributed training with FP8 mixed-precision.

Key Changes:

  • Implemented fsdp_pre_all_gather and fsdp_post_all_gather hooks for both Float8Tensor and MXFP8Tensor to handle FSDP weight sharding/gathering lifecycle
  • Added custom __torch_dispatch__ handlers for FSDP-required operations: aten.split.Tensor, aten.new_zeros, aten.as_strided, aten.copy_, and aten.slice.Tensor
  • Enhanced transpose caching logic to properly maintain transposed views through various tensor operations
  • Added training state-aware quantizer usage control (rowwise vs columnwise) based on forward/backward pass detection

Major Implementation Details:

  • FSDP2 integration distinguishes between forward/backward passes using TrainingState.PRE_BACKWARD to selectively gather only needed tensor representations
  • For MXFP8, operations validate 128-byte alignment constraints and fall back to dequantization when constraints aren't met
  • Transpose cache maintenance across splits, views, and resharding ensures performance optimization for Hopper/L40 architectures

Issues Found:
Multiple critical None-handling bugs exist in MXFP8 dispatch handlers where operations assume non-None data/scale tensors, which would cause AttributeError at runtime when certain usage flags are disabled.

Confidence Score: 3/5

  • This PR has several critical runtime issues that need resolution before merging, particularly around None-handling in MXFP8 tensor operations
  • Score reflects multiple logic bugs identified by previous reviewers (AttributeError, NameError, variable shadowing) that would cause runtime failures in MXFP8 operations. While Float8Tensor changes appear more robust, MXFP8Tensor has ~8-10 critical None-handling issues across split, slice, copy, and post_all_gather operations that need fixes
  • transformer_engine/pytorch/tensor/mxfp8_tensor.py requires significant attention for None-handling fixes across all new dispatch handlers before this can safely merge

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/tensor/float8_tensor.py 4/5 Adds FSDP2 support with fsdp_pre_all_gather and fsdp_post_all_gather hooks, implements aten.split.Tensor, aten.new_zeros, and aten.as_strided handlers with transpose caching for FP8 tensors
transformer_engine/pytorch/tensor/mxfp8_tensor.py 3/5 Implements FSDP2 support and multiple torch dispatch handlers (split, as_strided, copy_, slice, new_zeros) for MXFP8 tensors; contains several critical None-handling issues that need resolution

Sequence Diagram

sequenceDiagram
    participant FSDP2
    participant Float8Tensor/MXFP8Tensor
    participant Quantizer
    participant DeviceMesh

    Note over FSDP2: Forward Pass (weights needed)
    FSDP2->>Float8Tensor/MXFP8Tensor: fsdp_pre_all_gather(mesh, module, ...)
    Float8Tensor/MXFP8Tensor->>Quantizer: check training_state & reshard_after_forward
    Quantizer->>Quantizer: set_usage(rowwise=True, columnwise=False)
    Float8Tensor/MXFP8Tensor->>FSDP2: return (sharded_data, metadata)
    FSDP2->>DeviceMesh: all_gather(sharded_data)
    DeviceMesh->>FSDP2: all_gather_outputs
    FSDP2->>Float8Tensor/MXFP8Tensor: fsdp_post_all_gather(outputs, metadata, ...)
    Float8Tensor/MXFP8Tensor->>Float8Tensor/MXFP8Tensor: reconstruct full tensor
    Float8Tensor/MXFP8Tensor->>Float8Tensor/MXFP8Tensor: update_usage(rowwise=True)
    Float8Tensor/MXFP8Tensor->>FSDP2: return gathered_tensor
    
    Note over FSDP2: Compute forward pass
    
    Note over FSDP2: Backward Pass (gradients computed)
    FSDP2->>Float8Tensor/MXFP8Tensor: fsdp_pre_all_gather(mesh, module, ...)
    Float8Tensor/MXFP8Tensor->>Quantizer: check training_state (PRE_BACKWARD)
    Quantizer->>Quantizer: set_usage(rowwise=False, columnwise=True)
    Float8Tensor/MXFP8Tensor->>FSDP2: return (transpose_data, metadata)
    FSDP2->>DeviceMesh: all_gather(transpose_data)
    DeviceMesh->>FSDP2: all_gather_outputs
    FSDP2->>Float8Tensor/MXFP8Tensor: fsdp_post_all_gather(outputs, metadata, ...)
    Float8Tensor/MXFP8Tensor->>Float8Tensor/MXFP8Tensor: reconstruct with transpose
    Float8Tensor/MXFP8Tensor->>Float8Tensor/MXFP8Tensor: update_usage(columnwise=True)
    Float8Tensor/MXFP8Tensor->>FSDP2: return gathered_tensor

2 files reviewed, 1 comment


Signed-off-by: Varun Thumbe <[email protected]>

@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR significantly enhances the FSDP2 test infrastructure for Transformer Engine by adding comprehensive support for FP8 mixed-precision training with distributed sharding.

Key Changes:

  • Expanded FP8 recipe support: added Float8CurrentScaling and MXFP8BlockScaling alongside existing DelayedScaling
  • Introduced flexible layer configuration system supporting 5 TE layer types (Linear, LayerNormLinear, LayerNormMLP, MultiheadAttention, TransformerLayer)
  • Added meta device initialization workflow for deferred parameter materialization after FSDP2 sharding
  • Implemented test_fp8_fsdp2_allgather() validation function to verify FP8 allgather correctness against manual FP32 allgather
  • Enhanced custom attribute save/restore logic to handle QuantizedTensor metadata correctly with FSDP2 DTensors
  • Replaced simple 3-layer network with configurable multi-layer architecture supporting both reshard_after_forward=True/False test cases

The test file is well-structured with clear separation of concerns: model initialization, FSDP2 setup, training loop, and validation logic.

Confidence Score: 5/5

  • This PR is safe to merge with high confidence - the changes are well-tested, properly structured, and add comprehensive FSDP2 support.
  • Score reflects thorough implementation with proper error handling, comprehensive test coverage of multiple FP8 recipes and layer types, correct FSDP2 integration patterns (save/restore custom attrs, DTensor handling), and validation logic to verify FP8 allgather correctness. The code follows established patterns and includes clear documentation.
  • No files require special attention - the test file is comprehensive and correctly implements FSDP2 FP8 support.

Important Files Changed

File Analysis

Filename Score Overview
tests/pytorch/distributed/run_fsdp2_model.py 5/5 Comprehensive FSDP2 test script adding support for multiple FP8 recipes, flexible layer configurations, meta device initialization, and FP8 allgather validation

Sequence Diagram

sequenceDiagram
    participant Main as Main Process
    participant Init as Model Init
    participant FSDP as FSDP2 Sharding
    participant Train as Training Loop
    participant Test as FP8 Test

    Main->>Main: Parse args & setup distributed
    Main->>Init: Create FP8 recipe (delayed/current/mx_fp8)
    
    alt FP8 Init Enabled
        Init->>Init: fp8_model_init(recipe)
    else FP8 Init Disabled
        Init->>Init: nullcontext()
    end
    
    Init->>Init: init_te_model(config)
    Note over Init: Create model on meta/cuda device
    
    Init->>FSDP: save_custom_attrs(model)
    Note over FSDP: Save QuantizedTensor metadata
    
    FSDP->>FSDP: get_device_mesh(world_size, sharding_dims)
    Note over FSDP: Setup FSDP or HSDP mesh
    
    FSDP->>FSDP: shard_model_with_fsdp2(model, mesh)
    Note over FSDP: Apply fully_shard to children & root
    
    FSDP->>FSDP: restore_custom_attrs(model, custom_attrs)
    Note over FSDP: Restore FP8 metadata to DTensors
    
    alt Meta Device Init
        FSDP->>FSDP: reset_parameters()
        Note over FSDP: Materialize sharded params on cuda
    end
    
    FSDP->>Train: Create optimizer
    
    loop For each iteration
        Train->>Train: Generate input & target
        Train->>Train: Forward with te.autocast(recipe)
        Train->>Train: Compute loss
        Train->>Train: Backward pass
        Train->>Train: Optimizer step
    end
    
    alt FP8 Init Enabled
        Train->>Test: test_fp8_fsdp2_allgather(model)
        Test->>Test: Manual FP32 allgather
        Test->>Test: FSDP2 FP8 allgather (unshard)
        Test->>Test: Validate both match
        Test->>Test: Reshard model
    end
    
    Main->>Main: Destroy process group

1 file reviewed, no comments


Signed-off-by: Varun Thumbe <[email protected]>

@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR adds FSDP2 (Fully Sharded Data Parallel v2) support to Transformer Engine by enabling DTensor parameter handling in the base module.

Key Changes:

  • Modified register_parameter() to prevent overwriting FP8-specific metadata when FSDP2 re-registers parameters as DTensors
  • Enhanced reset_parameters() to detect and handle DTensor parameters by operating on their local tensors
  • Added device mesh integration for Float8CurrentScalingQuantizer to configure amax reduction groups for distributed training
  • Implemented proper DTensor reconstruction after meta-device materialization
  • Ensured quantized local tensors are correctly wrapped back into DTensor parameters

Integration Points:

  • DTensor detection via isinstance(param, DTensor) check
  • Local tensor extraction and manipulation via param._local_tensor
  • Device mesh group configuration for FP8 scaling synchronization across shards
  • Parameter wrapping preserves both DTensor structure and FP8 quantization
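
As a rough sketch of the DTensor handling described above (assumptions only; the actual reset_parameters in base.py also covers meta-device materialization, non-DTensor parameters, and other quantizer types, and the quantize call shown here is hypothetical):

```python
# Hedged sketch, not the code in transformer_engine/pytorch/module/base.py:
# quantize the *local* shard of a DTensor parameter and configure amax
# reduction across the FSDP mesh for current scaling.
import torch
from torch.distributed.tensor import DTensor


def reset_dtensor_param_sketch(module, name, init_fn, quantizer):
    param = getattr(module, name)
    is_dtensor = isinstance(param, DTensor)
    # Work on the local shard; the DTensor keeps the global view/placements.
    local = param._local_tensor if is_dtensor else param

    init_fn(local)  # high-precision init of this shard

    if is_dtensor and hasattr(quantizer, "with_amax_reduction"):
        # Current scaling: reduce amax across shards so every shard ends up
        # with the same scale inverse after optimizer updates.
        quantizer.with_amax_reduction = True
        quantizer.amax_reduction_group = param.device_mesh.get_group()

    quantized_local = quantizer(local)  # hypothetical quantize call

    if is_dtensor:
        param._local_tensor = quantized_local  # keep DTensor structure intact
        new_param = torch.nn.Parameter(param)
    else:
        new_param = torch.nn.Parameter(quantized_local)
    setattr(module, name, new_param)
```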

Confidence Score: 4/5

  • This PR is safe to merge with minor considerations for edge cases in DTensor handling
  • The implementation correctly handles DTensor parameter registration and initialization. The logic properly distinguishes between DTensor and regular tensors, extracts local tensors for processing, and reconstructs DTensors with appropriate device mesh configuration. The amax reduction group setup for Float8CurrentScalingQuantizer is correctly conditioned on both DTensor type and quantizer type. However, the score is 4 instead of 5 because: (1) the high-precision init value methods are attached to local tensors which relies on DTensor's attribute delegation pattern, and (2) there's no explicit validation that dtensor_param maintains valid device_mesh/placements attributes throughout the flow, though the logic appears sound
  • No files require special attention beyond standard FSDP2 testing with FP8 quantization enabled

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/module/base.py 4/5 Added FSDP2 DTensor support in parameter registration and reset, including proper handling of local tensors, device mesh configuration for FP8 quantization, and parameter wrapping

Sequence Diagram

sequenceDiagram
    participant FSDP2 as FSDP2
    participant Module as TransformerEngineBaseModule
    participant ResetParams as reset_parameters()
    participant Quantizer as Float8CurrentScalingQuantizer
    participant DTensor as DTensor

    FSDP2->>Module: register_parameter(name, DTensor)
    Note over Module: Check if param_init_meta exists<br/>Only initialize once to preserve FP8 kwargs
    Module->>Module: Store param_init_meta[name]
    
    FSDP2->>ResetParams: Trigger parameter initialization
    ResetParams->>ResetParams: Check if param is DTensor
    ResetParams->>DTensor: Extract _local_tensor
    
    alt Parameter on meta device
        ResetParams->>ResetParams: Create empty_like on cuda
        ResetParams->>DTensor: Reconstruct DTensor.from_local()<br/>with device_mesh & placements
    end
    
    ResetParams->>ResetParams: Apply init_fn to local tensor
    
    alt FP8 quantization enabled
        ResetParams->>Quantizer: Configure quantizer settings
        alt Is DTensor && Float8CurrentScaling
            ResetParams->>DTensor: Get device_mesh
            ResetParams->>Quantizer: Set amax_reduction_group<br/>from device_mesh.get_group()
            ResetParams->>Quantizer: Enable with_amax_reduction
        end
        ResetParams->>Quantizer: Quantize local tensor
        Quantizer-->>ResetParams: Return QuantizedTensor
    end
    
    ResetParams->>DTensor: Update _local_tensor with quantized tensor
    ResetParams->>DTensor: Wrap as nn.Parameter
    ResetParams->>Module: setattr(name, DTensor parameter)

1 file reviewed, no comments


Signed-off-by: Varun Thumbe <[email protected]>

@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

Enables FSDP2 training with FP8/MXFP8 initialized weights by implementing custom allgather hooks (fsdp_pre_all_gather and fsdp_post_all_gather) that serialize FP8 tensors to uint8 for distributed communication and reconstruct them post-allgather.

Key Changes:

  • FP8 Allgather Support: Float8Tensor and MXFP8Tensor now implement FSDP2 hooks that return uint8 data with metadata (scale_inv, dtype, quantizer) for allgather, enabling FP8 communication instead of high-precision
  • Selective Usage Based on Training State: Pre-allgather hooks optimize memory by gathering only rowwise data for forward pass and columnwise data for backward pass when reshard_after_forward=True
  • DTensor Integration: TransformerEngineBaseModule.reset_parameters() now handles FSDP2's DTensor parameters by operating on _local_tensor and preserving FP8 metadata across parameter re-registration
  • Transpose Cache Management: Enhanced __torch_dispatch__ handlers for split/view/new_zeros/as_strided ops to maintain transpose caches for both data and data_transpose, improving performance
  • Amax Reduction Setup: Quantizers are configured with appropriate reduction groups for synchronized scale updates across FSDP shards

Issues Found:

  • Potential tensor unpacking bug in mxfp8_tensor.py:613-617 where both [:2] and [-2:] slicing could select duplicate tensors if validation fails

Confidence Score: 4/5

  • This PR is largely safe to merge with one logical issue that needs verification in edge cases
  • The implementation is well-structured and addresses a significant feature gap (FSDP2 support for FP8 weights). The core allgather hook logic is sound and properly handles the forward/backward pass distinction. However, there's a potential edge-case bug in MXFP8Tensor's fsdp_post_all_gather where tensor unpacking could fail if the tuple length doesn't match usage flags, though this is unlikely in normal operation since the pre/post hooks are paired
  • transformer_engine/pytorch/tensor/mxfp8_tensor.py - verify the tensor unpacking logic at lines 613-617 handles all edge cases correctly

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/tensor/float8_tensor.py 4/5 Adds FSDP2 support with fsdp_pre_all_gather and fsdp_post_all_gather hooks, implements FP8 allgather by returning uint8 data with metadata for reconstruction, enhances __torch_dispatch__ to handle transpose caching for split/view/new_zeros/as_strided ops
transformer_engine/pytorch/tensor/mxfp8_tensor.py 3/5 Adds FSDP2 allgather hooks for MXFP8 tensors with selective rowwise/columnwise data gathering based on training state, implements torch dispatch handlers for split/as_strided/copy_/slice/new_zeros ops with MXFP8 block scaling constraints, has potential tensor unpacking issue in fsdp_post_all_gather
transformer_engine/pytorch/distributed.py 4/5 Adds _get_module_fsdp_state helper with LRU caching to retrieve FSDP state from modules or their closest FSDP parent
transformer_engine/pytorch/module/base.py 4/5 Updates reset_parameters to handle DTensor (FSDP2) by operating on _local_tensor, preserves FP8 metadata during FSDP2's re-registration of parameters as DTensors, sets up amax reduction groups for DTensor quantizers

Sequence Diagram

sequenceDiagram
    participant FSDP2 as FSDP2
    participant Float8Tensor as Float8Tensor/MXFP8Tensor
    participant Quantizer as Quantizer
    participant Module as TransformerEngineModule
    
    Note over FSDP2,Module: Forward Pass (or Backward if reshard_after_forward=True)
    
    FSDP2->>Float8Tensor: fsdp_pre_all_gather(mesh, orig_size, module, ...)
    Float8Tensor->>Module: _get_module_fsdp_state(module)
    Module-->>Float8Tensor: fsdp_state
    
    Float8Tensor->>Quantizer: copy()
    Quantizer-->>Float8Tensor: quantizer_copy
    
    alt reshard_after_forward=True
        Float8Tensor->>Float8Tensor: Determine forward vs backward from training_state
        Float8Tensor->>Quantizer: set_usage(rowwise=!is_backward, columnwise=is_backward)
        Note over Float8Tensor: Pack only needed data based on pass direction
    else reshard_after_forward=False
        Note over Float8Tensor: Pack both rowwise and columnwise if needed
    end
    
    Float8Tensor-->>FSDP2: (sharded_uint8_tensors, metadata)
    
    FSDP2->>FSDP2: AllGather uint8 tensors across ranks
    
    FSDP2->>Float8Tensor: fsdp_post_all_gather(all_gather_outputs, metadata, param_dtype, out)
    
    Float8Tensor->>Float8Tensor: Unpack all_gather_outputs and metadata
    
    alt out exists
        Float8Tensor->>Float8Tensor: update_usage() on existing tensor
    else out is None
        Float8Tensor->>Float8Tensor: Construct new Float8Tensor/MXFP8Tensor
        Float8Tensor->>Float8Tensor: update_usage() on new tensor
    end
    
    Float8Tensor-->>FSDP2: (reconstructed_fp8_tensor, all_gather_outputs)
    
    Note over FSDP2,Module: Tensor ready for forward/backward computation

1 file reviewed, 1 comment


@vthumbe1503
Collaborator Author

/te-ci pytorch

Signed-off-by: Varun Thumbe <[email protected]>

@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR makes focused changes to quantized_tensor.py to improve FSDP2 compatibility:

  • Recursive list handling: Added support for recursively updating lists of tensors in in-place operations (lines 436-439). This handles operations like split that return multiple tensors, ensuring QuantizedTensors within lists are properly updated.

  • Simplified make_like method: Removed the data parameter from the base class implementation (lines 493-506). The method now focuses solely on creating views of tensors. This change is safe because:

    • Subclasses like Float8Tensor override this method and still support the data parameter for backward compatibility
    • The base class docstring now correctly reflects that the method is "intended to create view of tensors"
    • Existing usages with data= parameter are handled by the overridden methods in subclasses

These are minimal, well-scoped changes that support the broader FSDP2 integration without breaking existing functionality.
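
A minimal, self-contained sketch of that recursive list handling, under assumptions (ToyQuantized and run_inplace_op are made-up stand-ins for the QuantizedTensor machinery; the real change lives in the in-place dispatch path of quantized_tensor.py):

```python
# Minimal sketch (assumptions, not the PR's code) of the dequantize -> compute
# -> requantize fallback for in-place ops, extended to walk lists of quantized
# tensors the way fused optimizers (e.g. torch._foreach_lerp_) pass them.
import torch


class ToyQuantized:
    """Hypothetical quantized wrapper with dequantize/requantize semantics."""

    def __init__(self, hp_tensor, scale=0.1):
        self.scale = scale
        self.q = torch.round(hp_tensor / scale).to(torch.int8)

    def dequantize(self):
        return self.q.float() * self.scale

    def requantize_(self, hp_tensor):
        self.q = torch.round(hp_tensor / self.scale).to(torch.int8)


def run_inplace_op(func, args):
    """Dequantize any ToyQuantized found in args (including inside lists),
    run the high-precision op, then write results back in place."""

    def to_hp(arg):
        if isinstance(arg, ToyQuantized):
            return arg.dequantize()
        if isinstance(arg, list):
            return [to_hp(a) for a in arg]  # recurse into lists of tensors
        return arg

    hp_args = [to_hp(a) for a in args]
    func(*hp_args)

    def write_back(arg, hp):
        if isinstance(arg, ToyQuantized):
            arg.requantize_(hp)             # update the quantized storage in place
        elif isinstance(arg, list):
            for a, h in zip(arg, hp):
                write_back(a, h)

    for arg, hp in zip(args, hp_args):
        write_back(arg, hp)


# Optimizer-style update: lerp applied to a *list* of quantized weights.
weights = [ToyQuantized(torch.randn(4)), ToyQuantized(torch.randn(4))]
targets = [torch.zeros(4), torch.zeros(4)]
run_inplace_op(torch._foreach_lerp_, [weights, targets, 0.5])
```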

Confidence Score: 4/5

  • This PR is safe to merge with minimal risk
  • The changes are minimal and focused, with only two small modifications to quantized_tensor.py. The recursive list handling is a straightforward addition that improves robustness. The make_like signature change is safe because subclasses override the method and maintain backward compatibility. No issues found that would impact correctness or introduce bugs.
  • No files require special attention

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/quantized_tensor.py 4/5 Added recursive list handling for in-place operations and simplified make_like method by removing data parameter. Changes are minimal and focused on improving FSDP2 compatibility.

Sequence Diagram

sequenceDiagram
    participant FSDP as FSDP2 Framework
    participant PreHook as fsdp_pre_allgather
    participant QT as QuantizedTensor
    participant PostHook as fsdp_post_allgather
    
    Note over FSDP: Forward/Backward Pass Begins
    FSDP->>PreHook: Call pre_allgather hook
    PreHook->>QT: Extract uint8 data + metadata
    Note over QT: For FP8: extract _data tensor<br/>For MXFP8: extract rowwise/columnwise data
    QT-->>PreHook: Return (uint8_tensors, metadata)
    PreHook-->>FSDP: Return allgather input
    
    Note over FSDP: Perform AllGather on uint8 data
    
    FSDP->>PostHook: Call post_allgather hook
    PostHook->>QT: Reconstruct from allgathered data
    Note over QT: Rebuild Float8/MXFP8 tensor<br/>from uint8 + metadata
    QT-->>PostHook: Return reconstructed tensor
    PostHook-->>FSDP: Return full tensor
    
    Note over FSDP: Continue computation with full tensor

1 file reviewed, no comments


@vthumbe1503
Collaborator Author

/te-ci pytorch


@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

Enables FSDP2 distributed training with FP8-initialized weights by implementing allgather hooks, torch dispatch operations, and DTensor support for TE quantized tensors.

Key Changes:

  • Implements fsdp_pre_all_gather and fsdp_post_all_gather hooks for Float8Tensor and MXFP8Tensor to enable 8-bit weight allgather (instead of high-precision)
  • Adds torch dispatch support for FSDP2 tensor operations: split, copy_, slice, view, as_strided, new_zeros
  • Updates reset_parameters in TransformerEngineBaseModule to handle DTensor for deferred initialization (meta device)
  • Fixes optimizer weight updates by recursively handling lists of quantized tensors in in-place operations
  • Moves quantizer usage validation from forward to backward pass to support phase-aware allgather
  • Configures amax reduction groups for current scaling quantizer to synchronize scale inverses across FSDP shards
  • Comprehensive test coverage for multiple TE layers with delayed/current/MX_FP8 scaling recipes

Memory Impact:

  • FP8 per-tensor scaling reduces memory footprint by ~50% vs BF16 on Blackwell (as expected)
  • MXFP8 block scaling maintains similar memory to BF16 due to rowwise+columnwise storage requirements
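
As a rough back-of-the-envelope check (assuming 1 byte per FP8 element, 2 bytes per BF16 element, and 32-element MXFP8 scaling blocks): a weight with N elements takes about 2N bytes in BF16, about N bytes with per-tensor-scaled FP8 (the single scale is negligible), and about 2N·(1 + 1/32) ≈ 2.06N bytes with MXFP8 when both the rowwise and columnwise copies plus their block scales are kept, which is why MXFP8 lands at roughly BF16-level memory.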

Confidence Score: 4/5

  • This PR is mostly safe to merge with one critical bug fix needed in MXFP8 shape handling
  • Score reflects solid implementation with comprehensive test coverage, but deducted 1 point due to critical bug in mxfp8_tensor.py:658 where both rowwise/columnwise data can be None causing AttributeError, and minor concerns about LRU cache causing potential memory leaks
  • transformer_engine/pytorch/tensor/mxfp8_tensor.py:658 requires fix for None handling

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/tensor/float8_tensor.py 4/5 Adds FSDP2 allgather hooks and torch dispatch ops (view, split, copy, slice, as_strided, new_zeros) to support 8-bit weight sharding. Implements current scaling quantizer sync for FSDP weight updates. Potential issue with shape handling in line 658.
transformer_engine/pytorch/tensor/mxfp8_tensor.py 3/5 Implements FSDP2 support with rowwise/columnwise data handling for block-scaled FP8. Adds dispatch ops (split, copy, slice, as_strided, new_zeros). Critical bug at line 658 where both data tensors can be None causing AttributeError.
transformer_engine/pytorch/module/base.py 4/5 Updates reset_parameters to handle DTensor (FSDP2 deferred init) by quantizing local tensor and reconstructing DTensor. Adds amax_reduction_group configuration for current scaling. Guard prevents metadata loss during DTensor conversion.
transformer_engine/pytorch/distributed.py 4/5 Adds _get_module_fsdp_state helper with LRU cache to find FSDP state for allgather hooks. Cache could cause memory leaks but likely acceptable given module stability during training.
transformer_engine/pytorch/quantized_tensor.py 5/5 Fixes in-place ops to recursively handle lists of tensors (optimizer sends batched updates). Removes data parameter from make_like to avoid confusion between view creation and data initialization.

Sequence Diagram

sequenceDiagram
    participant FSDP as FSDP2
    participant QT as QuantizedTensor (FP8/MXFP8)
    participant Helper as _get_module_fsdp_state
    participant Optimizer as Optimizer
    
    Note over FSDP,QT: Forward Pass - Weight Allgather
    FSDP->>QT: fsdp_pre_all_gather(module, mesh, ...)
    QT->>Helper: Get FSDP state to determine phase
    Helper-->>QT: training_state, reshard_after_forward
    QT->>QT: Set quantizer.rowwise_usage=True, columnwise=False
    QT-->>FSDP: (uint8_data, ...), metadata
    FSDP->>FSDP: All-gather uint8 shards
    FSDP->>QT: fsdp_post_all_gather(gathered_outputs, metadata)
    QT->>QT: Reconstruct FP8 tensor with rowwise usage
    QT-->>FSDP: Allgathered FP8 weight
    
    Note over FSDP,QT: Forward Pass Compute
    FSDP->>QT: Forward computation with FP8 weights
    
    Note over FSDP,QT: Backward Pass - Weight Allgather (if reshard_after_forward)
    FSDP->>QT: fsdp_pre_all_gather(module, mesh, ...)
    QT->>Helper: Get FSDP state
    Helper-->>QT: training_state=PRE_BACKWARD
    QT->>QT: Set quantizer.rowwise=False, columnwise_usage=True
    QT-->>FSDP: (uint8_data_transpose, ...), metadata
    FSDP->>FSDP: All-gather transpose/columnwise shards
    FSDP->>QT: fsdp_post_all_gather(gathered_outputs, metadata)
    QT->>QT: Reconstruct FP8 tensor with columnwise usage
    QT-->>FSDP: Allgathered FP8 weight
    
    Note over FSDP,Optimizer: Gradient Computation & Weight Update
    FSDP->>FSDP: Compute gradients, reduce-scatter
    Optimizer->>QT: In-place update (lerp on list of tensors)
    QT->>QT: Dequantize, apply op, quantize with amax reduction
    Note over QT: Amax synchronized across shards<br/>for current scaling

11 files reviewed, 2 comments


return inp, handle


@lru_cache


style: @lru_cache on instance methods can cause memory leaks since module references stay cached. The FSDP state itself is stateful and mutated during training, so caching based on module identity could potentially return stale references if modules are recreated. Consider @lru_cache(maxsize=128) with explicit cache invalidation or verify modules are never recreated during training.
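
For context, a minimal runnable sketch of the caching pattern at issue (hypothetical names, not the PR's helper): the cache keys on module identity and hands back a reference to a state object that FSDP2 keeps mutating in place, so the cached value stays current, while a bounded maxsize limits how long module references are pinned.

```python
# Hypothetical sketch of the lru_cache pattern discussed above; FakeFSDPState
# and the registry dict stand in for FSDP2's real per-module state.
from functools import lru_cache

import torch


class FakeFSDPState:
    def __init__(self):
        self.training_state = "IDLE"


_states = {}  # pretend FSDP registry: module -> state object


@lru_cache(maxsize=128)  # bounded, as suggested in the comment above
def get_module_fsdp_state_sketch(module: torch.nn.Module):
    return _states.get(module)


m = torch.nn.Linear(4, 4)
_states[m] = FakeFSDPState()
state = get_module_fsdp_state_sketch(m)        # cached on module identity
_states[m].training_state = "PRE_BACKWARD"     # FSDP mutates the state in place
assert state.training_state == "PRE_BACKWARD"  # cached reference sees the update
```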

columnwise_scale_inv=columnwise_scale_inv,
fp8_dtype=fp8_dtype,
dtype=param_dtype,
shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape,


logic: if both rowwise_data and columnwise_data are None (when both usage flags are False), accessing .shape raises AttributeError

Suggested change
shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape,
shape=rowwise_data.shape if rowwise_data is not None else (columnwise_data.shape if columnwise_data is not None else torch.Size([0])),

Signed-off-by: Varun Thumbe <[email protected]>
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch


@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR enables FSDP2 training with FP8-initialized weights by implementing custom allgather hooks and torch dispatch handlers. The implementation addresses three key issues: memory footprint with FP8 weights, correct weight updates during training, and efficient 8-bit weight allgather.

Key Changes:

  • Implements fsdp_pre_all_gather and fsdp_post_all_gather hooks for Float8Tensor and MXFP8Tensor to handle FP8/MXFP8 allgather using uint8 data
  • Adds torch dispatch handlers for split, slice, copy, new_zeros, as_strided, and view operations needed by FSDP2
  • Uses FSDP state to detect forward vs backward pass and set appropriate rowwise/columnwise quantizer usage
  • Sets amax_reduction_group for current scaling quantization to synchronize scale inverses across shards
  • Updates DTensor parameters correctly during deferred initialization (meta device)
  • Moves quantizer usage validation from layer forward() to _apply_forward/backward functions to accommodate FSDP2's separate allgather for forward/backward

Critical Issues Found:

  • mxfp8_tensor.py:502 and mxfp8_tensor.py:389 have potential AttributeError when accessing .shape on tensors that can be None (when neither rowwise nor columnwise data exists)

Confidence Score: 3/5

  • This PR introduces critical bugs that will cause runtime failures in edge cases, but the core FSDP2 integration logic is sound
  • Score of 3 reflects two critical logic errors in MXFP8Tensor dispatch handlers (lines 389 and 502) that will cause AttributeError when accessing .shape on None values. These bugs occur when quantizer has neither rowwise nor columnwise usage enabled, which may be rare but is not prevented. The rest of the implementation is well-designed with proper handling of forward/backward distinction, amax reduction groups, and DTensor support
  • transformer_engine/pytorch/tensor/mxfp8_tensor.py lines 389 and 502 require immediate fixes to handle None tensor data

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/tensor/mxfp8_tensor.py 2/5 Adds FSDP2 torch dispatch handlers for split, slice, copy, new_zeros, as_strided, and view ops. Contains critical bug: line 502 accesses .shape on out_data[0] which can be None when neither rowwise nor columnwise data exists
transformer_engine/pytorch/tensor/float8_tensor.py 3/5 Adds FSDP2 allgather hooks and torch dispatch handlers for various ops. Implements rowwise/columnwise usage tracking for forward/backward passes. Generally well-structured but relies on cached FSDP state lookup
transformer_engine/pytorch/distributed.py 4/5 Adds _get_module_fsdp_state helper with @lru_cache to find FSDP state for modules. Cache is appropriate since it stores reference to mutable state object
transformer_engine/pytorch/module/base.py 4/5 Updates reset_parameters to handle DTensor params for FSDP2 deferred init, sets amax reduction group for current scaling quantization. Logic is sound

Sequence Diagram

sequenceDiagram
    participant FSDP2
    participant Float8Tensor
    participant MXFP8Tensor
    participant Quantizer
    participant TE_Module
    
    Note over FSDP2,TE_Module: Forward Pass
    FSDP2->>Float8Tensor: fsdp_pre_all_gather(mesh, module, ...)
    Float8Tensor->>Float8Tensor: Set amax_reduction_group for current scaling
    Float8Tensor->>Float8Tensor: Get FSDP state, check reshard_after_forward
    Float8Tensor->>Quantizer: copy() and set_usage(rowwise=True)
    Float8Tensor-->>FSDP2: Return (uint8 data, metadata)
    FSDP2->>FSDP2: All-gather uint8 data across shards
    FSDP2->>Float8Tensor: fsdp_post_all_gather(gathered_data, metadata)
    Float8Tensor->>Float8Tensor: Reconstruct with rowwise usage
    Float8Tensor-->>FSDP2: Return reconstructed Float8Tensor
    FSDP2->>TE_Module: forward(input)
    TE_Module->>TE_Module: Validate rowwise usage in _apply_forward
    TE_Module-->>FSDP2: output
    
    Note over FSDP2,TE_Module: Backward Pass (if reshard_after_forward)
    FSDP2->>Float8Tensor: fsdp_pre_all_gather(mesh, module, ...)
    Float8Tensor->>Float8Tensor: Get FSDP state, detect PRE_BACKWARD
    Float8Tensor->>Quantizer: copy() and set_usage(columnwise=True)
    Float8Tensor-->>FSDP2: Return (uint8 data, metadata)
    FSDP2->>FSDP2: All-gather uint8 data across shards
    FSDP2->>Float8Tensor: fsdp_post_all_gather(gathered_data, metadata)
    Float8Tensor->>Float8Tensor: Reconstruct with columnwise usage
    Float8Tensor-->>FSDP2: Return reconstructed Float8Tensor
    TE_Module->>TE_Module: Validate columnwise usage in _apply_backward
    
    Note over FSDP2,TE_Module: Weight Initialization (Meta Device)
    TE_Module->>TE_Module: reset_parameters() with DTensor
    TE_Module->>Quantizer: Set amax_reduction_group from mesh
    TE_Module->>Float8Tensor: Initialize with FP8 quantization
    TE_Module->>TE_Module: Update DTensor.local_tensor

11 files reviewed, 2 comments


)
out_data.append(scale_inv_out)
return MXFP8Tensor(
shape=out_data[0].shape,


logic: AttributeError if both _rowwise_data and _columnwise_data are None - accessing .shape on out_data[0] will fail when neither usage flag is set

Suggested change
shape=out_data[0].shape,
shape=out_data[0].shape if out_data[0] is not None else out_data[1].shape,

out_data.append(scale_inv_out)
return [
MXFP8Tensor(
shape=splitted_tensor_data[0].size(),


logic: AttributeError if splitted_tensor_data[0] is None - happens when tensor._rowwise_data is None before splitting

Suggested change
shape=splitted_tensor_data[0].size(),
shape=splitted_tensor_data[0].size() if splitted_tensor_data[0] is not None else splitted_tensor_data[1].size(),


@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR enables end-to-end FSDP2 training for PyTorch models with Transformer Engine layers initialized with FP8/MXFP8 weights, solving three critical issues: memory footprint problems with FP8-initialized weights, weight update correctness during training, and enabling 8-bit allgather instead of high-precision.

Key Changes:

  • FSDP2 Allgather Hooks: Implements fsdp_pre_all_gather and fsdp_post_all_gather methods for Float8Tensor and MXFP8Tensor to support 8-bit weight allgather by returning uint8 data and metadata for reconstruction
  • Torch Dispatch Operations: Adds handlers for split, new_zeros, as_strided, copy_, slice, and view operations to support FSDP2 sharding and resharding of quantized tensors
  • FSDP State Management: Introduces _get_module_fsdp_state helper with LRU cache to determine forward/backward pass and reshard_after_forward configuration, enabling proper rowwise/columnwise usage selection
  • Current Scaling Synchronization: Sets amax reduction group in quantizers during allgather to ensure all weight shards share the same scale inverse after optimizer updates
  • DTensor Support: Updates reset_parameters in base module to handle DTensor parameters for FSDP2 deferred initialization with proper quantizer configuration
  • Quantized Tensor Fixes: Fixes in-place operations to handle lists of tensors (for optimizer lerp operations) and removes incorrect data parameter from make_like API
  • Usage Validation Refactoring: Moves quantizer usage validation from layer forward to forward/backward functions, and removes unnecessary quantizer updates when weights are already quantized

Memory Impact: FP8 per-tensor quantization reduces memory by 50% vs BF16 on Blackwell. MXFP8 has similar memory footprint to BF16 due to needing both rowwise/columnwise representations.

Test Coverage: Comprehensive tests cover delayed scaling, current scaling, and MX_FP8 block scaling recipes with various layer types (Linear, LayerNormLinear, TransformerLayer) and both FSDP/HSDP configurations.

Confidence Score: 4/5

  • Safe to merge with minor considerations - addresses long-standing FSDP2+FP8 issues with comprehensive implementation
  • Score of 4 reflects solid implementation with extensive test coverage addressing critical functionality gaps. The changes are well-architected with proper separation between FP8/MXFP8 tensor handling, FSDP2 hooks, and torch dispatch operations. Previous syntax errors in mxfp8_tensor.py mentioned in earlier comments have been fixed. Main concerns are: (1) LRU cache on _get_module_fsdp_state could retain module references indefinitely though the return value is a mutable state reference, (2) complex logic for determining forward/backward pass and reshard_after_forward could benefit from additional inline documentation, (3) MXFP8 view operation intentionally falls back to dequantize path with warning when flattening inner dimension. The PR resolves three critical GitHub issues (#1688, #401, #1135, #1188) and includes validation tests.
  • Pay close attention to transformer_engine/pytorch/tensor/float8_tensor.py and transformer_engine/pytorch/tensor/mxfp8_tensor.py for the complex torch dispatch logic and allgather hooks

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/distributed.py 4/5 Adds _get_module_fsdp_state helper with @lru_cache to find FSDP state for modules - enables determining forward/backward pass during allgather
transformer_engine/pytorch/tensor/float8_tensor.py 4/5 Implements FSDP2 hooks (fsdp_pre_all_gather, fsdp_post_all_gather) and torch dispatch for split/new_zeros/as_strided/copy operations to support 8-bit allgather
transformer_engine/pytorch/tensor/mxfp8_tensor.py 4/5 Implements FSDP2 hooks and torch dispatch for MXFP8 tensors with rowwise/columnwise data handling - includes split/as_strided/copy/slice operations
transformer_engine/pytorch/quantized_tensor.py 4/5 Fixes in-place ops to handle lists of tensors (for optimizer updates) and removes data parameter from make_like to fix view semantics
transformer_engine/pytorch/module/base.py 4/5 Adds DTensor support in reset_parameters for FSDP2 deferred initialization, handles amax reduction group setup for current scaling quantization
transformer_engine/pytorch/module/linear.py 4/5 Removes quantizer updates when weight is already quantized, moves columnwise usage validation from forward to backward function

Sequence Diagram

sequenceDiagram
    participant User
    participant FSDP2
    participant TEModule as TE Module
    participant Float8Tensor
    participant Quantizer
    participant AllGather as FSDP AllGather

    User->>FSDP2: Initialize model with fp8_model_init
    FSDP2->>TEModule: Create FP8/MXFP8 weight shards
    TEModule->>Float8Tensor: Initialize quantized weights
    Float8Tensor->>Quantizer: Setup amax reduction group
    
    User->>FSDP2: Start training iteration (forward pass)
    FSDP2->>Float8Tensor: fsdp_pre_all_gather(module, mesh)
    Float8Tensor->>TEModule: Get FSDP state via _get_module_fsdp_state
    Float8Tensor->>Quantizer: Set rowwise usage for forward
    Float8Tensor-->>FSDP2: Return (uint8_data,), metadata
    
    FSDP2->>AllGather: AllGather uint8 data across shards
    AllGather-->>FSDP2: Gathered uint8 data
    
    FSDP2->>Float8Tensor: fsdp_post_all_gather(outputs, metadata)
    Float8Tensor->>Float8Tensor: Reconstruct Float8Tensor with gathered data
    Float8Tensor->>Float8Tensor: update_usage(rowwise=True)
    Float8Tensor-->>FSDP2: Return reconstructed weight tensor
    
    FSDP2->>TEModule: Forward pass with gathered weights
    TEModule->>TEModule: Compute activations
    
    alt reshard_after_forward=True
        FSDP2->>FSDP2: Reshard weights after forward
    end
    
    User->>FSDP2: Backward pass
    
    alt reshard_after_forward=True
        FSDP2->>Float8Tensor: fsdp_pre_all_gather (backward)
        Float8Tensor->>Quantizer: Set columnwise usage for backward
        Float8Tensor-->>FSDP2: Return appropriate data for backward
        FSDP2->>AllGather: AllGather for backward pass
        FSDP2->>Float8Tensor: fsdp_post_all_gather
        Float8Tensor-->>FSDP2: Reconstructed weight for backward
    end
    
    FSDP2->>TEModule: Backward pass with weights
    TEModule->>TEModule: Compute gradients
    FSDP2->>FSDP2: ReduceScatter gradients
    
    User->>FSDP2: Optimizer step
    FSDP2->>Float8Tensor: Update weight shards (via lerp/copy_)
    Float8Tensor->>Quantizer: Sync amax across shards
    Quantizer->>Quantizer: AllReduce amax for single scale_inv

11 files reviewed, no comments

